/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysds.runtime.instructions.spark;


import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.PlusMultiply;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import scala.Tuple2;

public class CumulativeAggregateSPInstruction extends AggregateUnarySPInstruction {

	private CumulativeAggregateSPInstruction(AggregateUnaryOperator op, CPOperand in1, CPOperand out, String opcode, String istr) {
		super(SPType.CumsumAggregate, op, null, in1, out, null, opcode, istr);
	}

	public static CumulativeAggregateSPInstruction parseInstruction( String str ) {
		String[] parts = InstructionUtils.getInstructionPartsWithValueType( str );
		InstructionUtils.checkNumFields ( parts, 2 );
		String opcode = parts[0];
		CPOperand in1 = new CPOperand(parts[1]);
		CPOperand out = new CPOperand(parts[2]);
		AggregateUnaryOperator aggun = InstructionUtils.parseCumulativeAggregateUnaryOperator(opcode);
		return new CumulativeAggregateSPInstruction(aggun, in1, out, opcode, str);	
	}
	
	@Override
	public void processInstruction(ExecutionContext ec) {
		SparkExecutionContext sec = (SparkExecutionContext)ec;
		DataCharacteristics mc = sec.getDataCharacteristics(input1.getName());
		DataCharacteristics mcOut = new MatrixCharacteristics(mc);
		long rlen = mc.getRows();
		int blen = mc.getBlocksize();
		mcOut.setRows((long)(Math.ceil((double)rlen/blen)));
		
		//get input
		JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable( input1.getName() );
		
		//execute unary aggregate (w/ implicit drop correction)
		AggregateUnaryOperator auop = (AggregateUnaryOperator) _optr;
		JavaPairRDD<MatrixIndexes,MatrixBlock> out = 
			in.mapToPair(new RDDCumAggFunction(auop, rlen, blen));
		//merge partial aggregates, adjusting for correct number of partitions
		//as size can significant shrink (1K) but also grow (sparse-dense)
		int numParts = SparkUtils.getNumPreferredPartitions(mcOut);
		int minPar = (int)Math.min(SparkExecutionContext.getDefaultParallelism(true), mcOut.getNumBlocks());
		out = RDDAggregateUtils.mergeByKey(out, Math.max(numParts, minPar), false);
		
		//put output handle in symbol table
		sec.setRDDHandleForVariable(output.getName(), out);
		sec.addLineageRDD(output.getName(), input1.getName());
		sec.getDataCharacteristics(output.getName()).set(mcOut);
	}

	private static class RDDCumAggFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> 
	{
		private static final long serialVersionUID = 11324676268945117L;
		
		private final AggregateUnaryOperator _op;
		private UnaryOperator _uop = null;
		private final long _rlen;
		private final int _blen;
		
		public RDDCumAggFunction( AggregateUnaryOperator op, long rlen, int blen ) {
			_op = op;
			_rlen = rlen;
			_blen = blen;
		}
		
		@Override
		public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) 
			throws Exception 
		{
			MatrixIndexes ixIn = arg0._1();
			MatrixBlock blkIn = arg0._2();

			MatrixIndexes ixOut = new MatrixIndexes();
			MatrixBlock blkOut = new MatrixBlock();
			
			//process instruction
			AggregateUnaryOperator aop = _op;
			if( aop.aggOp.increOp.fn instanceof PlusMultiply ) { //cumsumprod
				aop.indexFn.execute(ixIn, ixOut);
				if( _uop == null )
					_uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucumk+*"));
				MatrixBlock t1 = blkIn.unaryOperations(_uop, new MatrixBlock());
				MatrixBlock t2 = blkIn.slice(0, blkIn.getNumRows()-1, 1, 1, new MatrixBlock());
				blkOut.reset(1, 2);
				blkOut.set(0, 0, t1.get(t1.getNumRows()-1, 0));
				blkOut.set(0, 1, t2.prod());
			}
			else { //general case
				OperationsOnMatrixValues.performAggregateUnary( ixIn, blkIn, ixOut, blkOut, aop, _blen);
				if( aop.aggOp.existsCorrection() )
					blkOut.dropLastRowsOrColumns(aop.aggOp.correction);
			}
			
			//cumsum expand partial aggregates
			long rlenOut = (long)Math.ceil((double)_rlen/_blen);
			long rixOut = (long)Math.ceil((double)ixIn.getRowIndex()/_blen);
			int rlenBlk = (int) Math.min(rlenOut-(rixOut-1)*_blen, _blen);
			int clenBlk = blkOut.getNumColumns();
			int posBlk = (int) ((ixIn.getRowIndex()-1) % _blen);
			
			//construct sparse output blocks (single row in target block size)
			MatrixBlock blkOut2 = new MatrixBlock(rlenBlk, clenBlk, true);
			blkOut2.copy(posBlk, posBlk, 0, clenBlk-1, blkOut, true);
			ixOut.setIndexes(rixOut, ixOut.getColumnIndex());
			
			//output new tuple
			return new Tuple2<>(ixOut, blkOut2);
		}
	}
}
